import numpy as np  # NumPy是一个用于科学计算的基础包,用于处理大型多维数组和矩阵
import faiss  # FAISS库用于高效的相似度搜索和稠密向量的聚类

# 定义FaissKNeighbors类，用于执行基于FAISS的K近邻搜索
class FaissKNeighbors:
    # 类初始化函数：初始化k值，FAISS资源对象res，以及用于存储数据的索引
    def __init__(self, k=1, res=None):
        self.index = None  # 用于存储训练数据的索引
        self.y = None  # 用于存储训练数据的标签
        self.k = int(k)  # 最近邻个数
        self.res = res  # FAISS GPU资源对象

    # 训练函数：将训练数据加入到FAISS索引中
    def fit(self, X, y):
        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y)

        # 使用欧氏距离的平面索引（CPU）
        d = X.shape[1]
        cpu_index = faiss.IndexFlatL2(d)

        # 如提供 GPU 资源，则迁移到 GPU（device 0）
        if self.res is not None:
            self.index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index)
        else:
            self.index = cpu_index

        # 添加向量并保存标签
        self.index.add(X)
        self.y = y.astype(np.int64, copy=False)
        return self

    # 预测函数：对新的数据集X进行分类预测
    def predict(self, X):
        if self.index is None or self.y is None:
            raise ValueError("FaissKNeighbors 未训练，请先调用 fit().")

        X = np.asarray(X, dtype=np.float32)
        # 搜索X中每个向量的k个最近邻
        distances, indices = self.index.search(X, self.k)
        votes = self.y[indices]  # 根据索引获得最近邻的标签
        # 通过投票机制得出最终预测的标签
        predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
        return predictions

    # 评分函数：计算预测准确率
    def score(self, X, y):
        y = np.asarray(y)
        preds = self.predict(X)
        accuracy = np.mean(preds == y)
        return float(accuracy)
